/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.solr.ltr;

import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Rescorer;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.Weight;
import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery;
import org.apache.solr.search.IncompleteRerankingException;
import org.apache.solr.search.QueryLimits;

/**
 * Implements the rescoring logic. The top documents returned by Solr, with their original scores,
 * will be processed by a {@link LTRScoringQuery} that will assign a new score to each document. The
 * top documents will then be re-sorted based on the new score.
 */
public class LTRRescorer extends Rescorer {

  /**
   * The {@link LTRScoringQuery} whose model weight assigns the second-pass scores. May be null when
   * the no-arg constructor is used.
   */
  private final LTRScoringQuery scoringQuery;

  public LTRRescorer() {
    this.scoringQuery = null;
  }

  public LTRRescorer(LTRScoringQuery scoringQuery) {
    this.scoringQuery = scoringQuery;
  }

  /**
   * Orders hits by ascending docID, so that {@link #scoreFeatures} can walk the index segments in
   * order while scoring.
   */
  private static final Comparator<ScoreDoc> docComparator = Comparator.comparingInt(a -> a.doc);

  /** Orders hits by descending score, breaking ties by ascending docID. */
  protected static final Comparator<ScoreDoc> scoreComparator =
      (a, b) -> {
        // Sort by score descending, then docID ascending:
        if (a.score > b.score) {
          return -1;
        } else if (a.score < b.score) {
          return 1;
        } else {
          // This subtraction can't overflow int
          // because docIDs are >= 0:
          return a.doc - b.doc;
        }
      };

  /**
   * Sift-down step for the min-heap (ordered by score) held in {@code hits[0..size-1]}: restores
   * the heap property below {@code root}, assuming both subtrees already satisfy it. Used by
   * {@link #heapify} and by {@link #scoreSingleHit} after replacing the heap's minimum.
   */
  protected static void heapAdjust(ScoreDoc[] hits, int size, int root) {
    final ScoreDoc doc = hits[root];
    final float score = doc.score;
    int i = root;
    // Loop while node i has at least a left child (i.e. i is an internal node).
    while (i <= ((size >> 1) - 1)) {
      final int lchild = (i << 1) + 1;
      final ScoreDoc ldoc = hits[lchild];
      final float lscore = ldoc.score;
      // Sentinel: a missing right child must never win the "smaller score" comparisons below.
      float rscore = Float.MAX_VALUE;
      final int rchild = (i << 1) + 2;
      ScoreDoc rdoc = null;
      if (rchild < size) {
        rdoc = hits[rchild];
        rscore = rdoc.score;
      }
      // Swap the current node with its smaller-scored child (if any child is smaller),
      // then continue sifting down from that child's position.
      if (lscore < score) {
        if (rscore < lscore) {
          hits[i] = rdoc;
          hits[rchild] = doc;
          i = rchild;
        } else {
          hits[i] = ldoc;
          hits[lchild] = doc;
          i = lchild;
        }
      } else if (rscore < score) {
        hits[i] = rdoc;
        hits[rchild] = doc;
        i = rchild;
      } else {
        // Neither child has a smaller score: heap property holds, we are done.
        return;
      }
    }
  }

  /** Builds, in place, a min-heap (by score) over the first {@code size} entries of {@code hits}. */
  protected static void heapify(ScoreDoc[] hits, int size) {
    // Sift down every internal node, from the last parent up to the root.
    for (int i = (size >> 1) - 1; i >= 0; i--) {
      heapAdjust(hits, size, i);
    }
  }

  /**
   * Rescores the first-pass documents with the LTR model and returns them re-sorted by the new
   * scores.
   *
   * @param searcher current IndexSearcher
   * @param firstPassTopDocs documents to rerank, carrying their original (first-pass) scores
   * @param docsToRerank maximum number of documents to rerank and return
   */
  @Override
  public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int docsToRerank)
      throws IOException {
    if ((docsToRerank == 0) || (firstPassTopDocs.scoreDocs.length == 0)) {
      return firstPassTopDocs;
    }
    final ScoreDoc[] firstPassResults = getFirstPassDocsRanked(firstPassTopDocs);
    // Never rerank more documents than the first pass actually matched.
    docsToRerank = Math.toIntExact(Math.min(docsToRerank, firstPassTopDocs.totalHits.value()));

    final ScoreDoc[] reranked = rerank(searcher, docsToRerank, firstPassResults);

    return new TopDocs(firstPassTopDocs.totalHits, reranked);
  }

  /**
   * Scores every first-pass hit with the model weight built from {@link #scoringQuery} and returns
   * the top {@code docsToRerank} hits sorted by the new score (descending).
   */
  private ScoreDoc[] rerank(IndexSearcher searcher, int docsToRerank, ScoreDoc[] firstPassResults)
      throws IOException {
    final ScoreDoc[] reranked = new ScoreDoc[docsToRerank];
    final List<LeafReaderContext> leaves = searcher.getIndexReader().leaves();
    final LTRScoringQuery.ModelWeight modelWeight =
        (LTRScoringQuery.ModelWeight)
            searcher.createWeight(searcher.rewrite(scoringQuery), ScoreMode.COMPLETE, 1);

    scoreFeatures(docsToRerank, modelWeight, firstPassResults, leaves, reranked);
    // Must sort all documents that we reranked, and then select the top
    Arrays.sort(reranked, scoreComparator);
    return reranked;
  }

  /**
   * Sorts the first-pass hits in place by ascending docID (the order required by
   * {@link #scoreFeatures}) and returns the same array.
   */
  protected static ScoreDoc[] getFirstPassDocsRanked(TopDocs firstPassTopDocs) {
    final ScoreDoc[] hits = firstPassTopDocs.scoreDocs;
    Arrays.sort(hits, docComparator);

    assert firstPassTopDocs.totalHits.relation() == TotalHits.Relation.EQUAL_TO;
    return hits;
  }

  /**
   * Scores each hit with the model, walking the index segments in order. {@code hits} must already
   * be sorted by ascending docID (see {@link #getFirstPassDocsRanked}); a new
   * {@code ModelScorer} is created each time the walk crosses into the next segment.
   *
   * @param docsToRerank maximum number of documents to keep in {@code reranked}
   * @param modelWeight the model weight used to create per-segment scorers
   * @param hits first-pass hits, sorted by ascending docID
   * @param leaves the searcher's leaf (segment) contexts
   * @param reranked output array (length {@code docsToRerank}) receiving the top rescored hits
   */
  public void scoreFeatures(
      int docsToRerank,
      LTRScoringQuery.ModelWeight modelWeight,
      ScoreDoc[] hits,
      List<LeafReaderContext> leaves,
      ScoreDoc[] reranked)
      throws IOException {

    int readerUpto = -1;
    int endDoc = 0;
    int docBase = 0;

    LTRScoringQuery.ModelWeight.ModelScorer scorer = null;
    int hitPosition = 0;

    while (hitPosition < hits.length) {
      final ScoreDoc hit = hits[hitPosition];
      LeafReaderContext readerContext = null;
      // Advance to the segment containing this hit (hits are in docID order, so we only
      // ever move forward through the leaves).
      while (hit.doc >= endDoc) {
        readerUpto++;
        readerContext = leaves.get(readerUpto);
        endDoc = readerContext.docBase + readerContext.reader().maxDoc();
      }
      // We advanced to another segment
      if (readerContext != null) {
        docBase = readerContext.docBase;
        scorer = modelWeight.modelScorer(readerContext);
      }
      scoreSingleHit(docsToRerank, docBase, hitPosition, hit, scorer, reranked);
      hitPosition++;
    }
  }

  /**
   * Scores a single document. Overwrites {@code hit.score} with the model score and maintains the
   * top-{@code docsToRerank} selection in {@code reranked}: positions are filled directly until
   * {@code docsToRerank} hits have been seen, then a min-heap is built over {@code reranked} and
   * each further hit replaces the heap minimum only if it scores higher. Throws
   * {@link IncompleteRerankingException} if {@link QueryLimits} signals the query limit was reached.
   */
  protected void scoreSingleHit(
      int docsToRerank,
      int docBase,
      int hitPosition,
      ScoreDoc hit,
      LTRScoringQuery.ModelWeight.ModelScorer scorer,
      ScoreDoc[] reranked)
      throws IOException {
    /*
     * Scorer for a LTRScoringQuery.ModelWeight should never be null since we always have to call
     * score even if no feature scorers match, since a model might use that info to return a
     * non-zero score. Same applies for the case of advancing a
     * LTRScoringQuery.ModelWeight.ModelScorer past the target doc since the model algorithm still
     * needs to compute a potentially non-zero score from blank features.
     */
    assert (scorer != null);
    // Convert the global docID to a segment-local one.
    final int targetDoc = hit.doc - docBase;
    scorer.docID();
    scorer.iterator().advance(targetDoc);

    // Expose the original (first-pass) score and docID to the model via the doc info.
    scorer.getDocInfo().setOriginalDocScore(hit.score);
    scorer.getDocInfo().setOriginalDocId(hit.doc);
    hit.score = scorer.score();
    if (QueryLimits.getCurrentLimits()
        .maybeExitWithPartialResults(
            "Learning To Rank rescoring -"
                + " The full reranking didn't complete."
                + " If partial results are tolerated the reranking got reverted and all documents preserved their original score and ranking.")) {
      throw new IncompleteRerankingException();
    }
    if (hitPosition < docsToRerank) {
      reranked[hitPosition] = hit;
    } else if (hitPosition == docsToRerank) {
      // collected topN document, I create the heap
      heapify(reranked, docsToRerank);
    }
    if (hitPosition >= docsToRerank) {
      // once that heap is ready, if the score of this document is greater that
      // the minimum I replace it with the
      // minimum and fix the heap.
      if (hit.score > reranked[0].score) {
        reranked[0] = hit;
        heapAdjust(reranked, docsToRerank, 0);
      }
    }
  }

  /** Explains the (second-pass) score of {@code docID} using {@link #scoringQuery}. */
  @Override
  public Explanation explain(IndexSearcher searcher, Explanation firstPassExplanation, int docID)
      throws IOException {
    return getExplanation(searcher, docID, scoringQuery);
  }

  /**
   * Builds the score explanation for {@code docID}. For an
   * {@link OriginalRankingLTRScoringQuery} the explanation comes from the wrapped original query's
   * weight; otherwise it comes from the reranking model's weight.
   */
  protected static Explanation getExplanation(
      IndexSearcher searcher, int docID, LTRScoringQuery rerankingQuery) throws IOException {
    final List<LeafReaderContext> leafContexts = searcher.getTopReaderContext().leaves();
    // Locate the segment containing docID and convert to a segment-local docID.
    final int n = ReaderUtil.subIndex(docID, leafContexts);
    final LeafReaderContext context = leafContexts.get(n);
    final int deBasedDoc = docID - context.docBase;
    final Weight rankingWeight;
    if (rerankingQuery instanceof OriginalRankingLTRScoringQuery) {
      rankingWeight =
          rerankingQuery.getOriginalQuery().createWeight(searcher, ScoreMode.COMPLETE, 1);
    } else {
      rankingWeight =
          searcher.createWeight(searcher.rewrite(rerankingQuery), ScoreMode.COMPLETE, 1);
    }
    return rankingWeight.explain(context, deBasedDoc);
  }
}
